from Model.InferenceModule.inference_module import InferenceModule
import numpy as np
from tianshou.data import Batch
from Model.InferenceModule.module_utils import trace_log_probs

class SinglePassiveModule(InferenceModule):
    """Inference module that runs a passive model for one named object.

    The wrapped model is picked once, at construction, from the four
    candidates passed in, based on two flags read from ``args.inter``:

    * ``use_all_as_single`` and ``use_active_as_passive``  -> ``all_model``
    * ``use_all_as_single`` only                            -> ``all_passive_model``
    * ``use_active_as_passive`` only                        -> ``full_models``
    * neither                                               -> ``single_passive_model``
    """

    def __init__(self, args, extractor, name, forward_dist, single_passive_model, full_models, all_passive_model, all_model):
        """Select the backing model and set up the optimizer.

        Args:
            args: configuration object; ``args.inter`` supplies the
                ``use_active_as_passive`` and ``use_all_as_single`` flags.
            extractor: forwarded to the ``InferenceModule`` constructor.
            name: name of the object this module predicts; used to slice
                targets/observations and to look up its index.
            forward_dist: stored on ``self.forward_dist``.
            single_passive_model, full_models, all_passive_model, all_model:
                candidate models; exactly one is kept as ``self.model``.
        """
        super().__init__(args, extractor)
        self.mp = args.inter
        self.name = name
        self.use_active_as_passive = self.mp.use_active_as_passive
        self.use_all_as_single = self.mp.use_all_as_single
        if self.use_all_as_single:
            self.model = all_model if self.use_active_as_passive else all_passive_model
        else:
            self.model = full_models if self.use_active_as_passive else single_passive_model

        self.forward_dist = forward_dist
        # The optimizer is initialized last, after self.model has been assigned.
        self.init_optimizer(args)

    def __call__(self, batch, valid, extractor, normalizer, additional=None, grad_settings=None, log_batch=None, keep_invalid=False, keep_all=False):
        """Run the selected model on ``batch`` and package the outputs.

        Args:
            batch: indexable batch exposing ``target`` and ``obs``; it is
                filtered by the flags returned from ``self.get_omit``.
            valid: validity array, filtered with the same flags and passed
                to the model.
            extractor: provides ``num_objects``, ``get_index``,
                ``get_named_target`` and ``get_named_obs``.
            normalizer: accepted for interface compatibility; not used in
                this method.
            additional: names of extra model outputs, passed to the model as
                ``ret_settings``. The i-th name is stored on the result from
                the i-th entry of the model's extra info. Defaults to an
                empty list.
            grad_settings: forwarded to the model unchanged. Defaults to an
                empty list.
            log_batch: keys copied verbatim from the filtered batch onto the
                result. Defaults to an empty list.
            keep_invalid: forwarded to ``self.get_omit``.
            keep_all: forwarded to ``self.get_omit``.

        Returns:
            A ``Batch`` with ``target``, ``params``, ``mask``, ``log_probs``,
            ``omit_flags``, ``passive_input`` and ``trace_log_probs``, plus
            one entry per name in ``additional`` and per key in ``log_batch``.
        """
        # Defaults are None rather than [] so that no list object is shared
        # between calls (Python evaluates default values once, at def time).
        additional = [] if additional is None else additional
        grad_settings = [] if grad_settings is None else grad_settings
        log_batch = [] if log_batch is None else log_batch

        omit_flags = self.get_omit(batch, keep_all=keep_all, keep_invalid=keep_invalid)
        batch = batch[omit_flags]
        valid = valid[omit_flags]

        if self.use_all_as_single:
            key_state = batch.target
        else:
            # note that single passive doesn't use the key state; it
            # shouldn't hurt to pass it in
            key_state = extractor.get_named_target(batch.target, names=self.name)

        if self.use_active_as_passive or self.use_all_as_single:
            query_state = batch.obs
        else:
            query_state = extractor.get_named_obs(batch.obs, names=self.name)

        if self.use_all_as_single:
            # identity matrix of size num_objects x num_objects
            mask = np.eye(extractor.num_objects)
        elif self.use_active_as_passive:
            # one-hot vector selecting only this module's own object
            mask = np.zeros(extractor.num_objects)
            mask[extractor.get_index(self.name)] = 1
        else:
            mask = None

        # run the model to get return values
        model_input = np.concatenate([key_state, query_state], axis=-1)
        params, mask, info = self.model(
            model_input,
            m=mask,
            valid=valid,
            dist_settings=['mixed'],
            ret_settings=additional,
            grad_settings=grad_settings,
        )
        # The model's info is a 5-tuple; the keys/queries entries are not
        # needed here. The last two entries are paired element-wise so that
        # info[i] corresponds to additional[i].
        passive_input, _keys, _queries, info1, info2 = info
        info = list(zip(info1, info2))

        if self.use_all_as_single:
            # assumes that values are of shape [batch, num_keys, ...], so index the appropriate key
            params, mask, info = self._single_index_all(self.name, extractor, params, mask, info)

        target, log_probs = self._target_dists(batch, params)
        result = Batch(
            target=target,
            params=params,
            mask=mask,
            log_probs=log_probs,
            omit_flags=omit_flags,
            passive_input=passive_input,
        )
        result.trace_log_probs = trace_log_probs(
            extractor.num_objects,
            result.log_probs,
            batch,
            idx=extractor.get_index(self.name),
        )
        for i, aname in enumerate(additional):
            result[aname] = info[i]
        for k in log_batch:
            result[k] = batch[k]
        return result